import logging
from typing import Any, Optional
import torch
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    LCMScheduler,
    Transformer2DModel,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from models.RewardPixart import RewardPixartPipeline, freeze_params
from models.RewardStableDiffusion import RewardStableDiffusion
from models.RewardStableDiffusionXL import RewardStableDiffusionXL
from models.RewardFlux import RewardFluxPipeline


def get_model(
    model_name: str,
    dtype: torch.dtype,
    device: torch.device,
    cache_dir: str,
    memsave: bool = False,
    enable_sequential_cpu_offload: bool = False,
):
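    """Load one of the supported reward diffusion pipelines by name.

    Supported values for `model_name`: "sd-turbo", "sdxl-turbo", "pixart",
    "hyper-sd", and "flux". The pipeline is returned without being moved to
    `device` (callers handle placement); if `enable_sequential_cpu_offload`
    is set, sequential CPU offload is enabled before returning.
    """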
    logging.info(f"Loading model: {model_name}")
    if model_name == "sd-turbo":
        pipe = RewardStableDiffusion.from_pretrained(
            "stabilityai/sd-turbo",
            torch_dtype=dtype,
            variant="fp16",
            cache_dir=cache_dir,
            memsave=memsave,
        )
        #pipe = pipe.to(device, dtype)
    elif model_name == "sdxl-turbo":
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            "stabilityai/sdxl-turbo",
            vae=vae,
            torch_dtype=dtype,
            variant="fp16",
            use_safetensors=True,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        #pipe = pipe.to(device, dtype)
    elif model_name == "pixart":
        pipe = RewardPixartPipeline.from_pretrained(
            "PixArt-alpha/PixArt-XL-2-1024-MS",
            torch_dtype=dtype,
            cache_dir=cache_dir,
            memsave=memsave,
        )
        pipe.transformer = Transformer2DModel.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="transformer",
            torch_dtype=dtype,
            cache_dir=cache_dir,
        )
        pipe.scheduler = DDPMScheduler.from_pretrained(
            "PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512",
            subfolder="scheduler",
            cache_dir=cache_dir,
        )

        # Speed up the T5 text encoder with BetterTransformer (requires the
        # `optimum` package).
        pipe.text_encoder.to_bettertransformer()
        pipe.transformer.eval()
        freeze_params(pipe.transformer.parameters())
        pipe.transformer.enable_gradient_checkpointing()
        #pipe = pipe.to(device)
    
    elif model_name == "hyper-sd":
        base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        repo_name = "ByteDance/Hyper-SD"
        ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
    
        # Build the UNet from the SDXL base config (randomly initialized;
        # stays on CPU in float32 until the checkpoint below is loaded).
        unet = UNet2DConditionModel.from_config(
            base_model_id, subfolder="unet", cache_dir=cache_dir
        )
    
        # Load the Hyper-SD one-step checkpoint into the UNet (the tensors are
        # read onto the GPU, then copied into the CPU parameters).
        unet.load_state_dict(
            load_file(
                hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
                device="cuda",
            )
        )
    
        # Assemble the SDXL pipeline around the Hyper-SD UNet (fp16 weights,
        # kept on CPU here).
        pipe = RewardStableDiffusionXL.from_pretrained(
            base_model_id,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16",  # Still set fp16 for later use on GPU
            cache_dir=cache_dir,
            is_hyper=True,
            memsave=memsave,
        )

        # Use the LCM scheduler (rather than DDIM) so that specific timestep
        # counts, e.g. a single step, are supported.
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    elif model_name == "flux":
        pipe = RewardFluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        #pipe.to(device, dtype)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    return pipe


def get_multi_apply_fn(
    model_type: str,
    seed: int,
    pipe: Optional[Any] = None,
    cache_dir: Optional[str] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
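    """Build a no-grad callable `(latents, prompt) -> images` for a fixed seed.

    For "flux" the pipeline passed via `pipe` is reused; for "sdxl" and "sd2"
    a fresh base pipeline is loaded, placed with sequential CPU offload, and
    run for 50 inference steps.
    """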
    generator = torch.Generator("cuda").manual_seed(seed)
    if model_type == "flux":
        return lambda latents, prompt: torch.no_grad(pipe.apply)(
            latents=latents,
            prompt=prompt,
            num_inference_steps=4,
            generator=generator,
        )
    elif model_type == "sdxl":
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusionXL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            vae=vae,
            use_safetensors=True,
            cache_dir=cache_dir,
        )
        pipe = pipe.to(device, dtype)
        pipe.enable_sequential_cpu_offload()
        return lambda latents, prompt: torch.no_grad(pipe.apply)(
            latents=latents,
            prompt=prompt,
            guidance_scale=5.0,
            num_inference_steps=50,
            generator=generator,
        )
    elif model_type == "sd2":
        sd2_base = "stabilityai/stable-diffusion-2-1-base"
        scheduler = EulerDiscreteScheduler.from_pretrained(
            sd2_base,
            subfolder="scheduler",
            cache_dir=cache_dir,
        )
        pipe = RewardStableDiffusion.from_pretrained(
            sd2_base,
            torch_dtype=dtype,
            cache_dir=cache_dir,
            scheduler=scheduler,
        )
        pipe = pipe.to(device, dtype)
        pipe.enable_sequential_cpu_offload()
        return lambda latents, prompt: torch.no_grad(pipe.apply)(
            latents=latents,
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=50,
            generator=generator,
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")