import os
import logging
import torch
import asyncio
import aiohttp
import requests
from huggingface_hub import hf_hub_download

# Logging: root configuration at DEBUG with "<time> - <level> - <message>" records
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Storage locations (DATA_ROOT can be overridden through the environment)
DATA_ROOT = os.getenv('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")

# Compute device: CUDA when torch reports it as available, otherwise CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face repository the model files are fetched from
HF_REPO_ID = "jbilcke-hf/model-cocktail"

# Files to fetch from the Hugging Face repository, given as paths relative to
# the repository root. Commented-out entries are kept as a record of files
# that are not downloaded at the moment.
MODEL_FILES = [
    # dwpose / face-detector
    "dwpose/dw-ll_ucoco_384.pth",
    "face-detector/s3fd-619a316812.pth",

    # liveportrait
    "liveportrait/spade_generator.pth",
    "liveportrait/warping_module.pth",
    "liveportrait/motion_extractor.pth",
    "liveportrait/stitching_retargeting_module.pth",
    "liveportrait/appearance_feature_extractor.pth",
    "liveportrait/landmark.onnx",

    # liveportrait, animal mode 🐢🐱 (disabled)
    # Upstream notes suggest animal mode does not support stitching yet:
    # https://github.com/KwaiVGI/LivePortrait/blob/main/assets/docs/changelog/2024-08-02.md#updates-on-animals-mode
    #"liveportrait-animals/warping_module.pth",
    #"liveportrait-animals/spade_generator.pth",
    #"liveportrait-animals/motion_extractor.pth",
    #"liveportrait-animals/appearance_feature_extractor.pth",
    #"liveportrait-animals/stitching_retargeting_module.pth",
    #"liveportrait-animals/xpose.pth",

    # insightface
    # HACK: the zipped buffalo_l archive is listed in addition to the
    # individual files below; the proper fix is probably in
    # liveportrait/utils/dependencies/insightface/utils/storage.py
    "insightface/models/buffalo_l.zip",

    "insightface/buffalo_l/det_10g.onnx",
    "insightface/buffalo_l/2d106det.onnx",

    # sd-vae-ft-mse
    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "sd-vae-ft-mse/config.json",

    # Not used yet
    #"flux-dev/flux-dev-fp8.safetensors",
    #"flux-dev/flux_dev_quantization_map.json",
    #"pulid-flux/pulid_flux_v0.9.0.safetensors",
    #"pulid-flux/pulid_v1.bin"
]

def create_directory(directory):
    """Create ``directory`` (including missing parents) and log the outcome.

    Creation is attempted directly instead of checking for existence first,
    so another process creating the same path between a check and the
    ``os.makedirs`` call can no longer make this function raise
    ``FileExistsError``.

    Args:
        directory: Path of the directory to create.

    Raises:
        OSError: If creation fails for a reason other than the path already
            existing (for example missing permissions).
    """
    try:
        os.makedirs(directory)
    except FileExistsError:
        # The path is already present: same outcome the former
        # ``os.path.exists`` pre-check reported.
        logger.info(f"  Directory already exists: {directory}")
    else:
        logger.info(f"  Directory created: {directory}")

def print_directory_structure(startpath):
    """Log the directory tree rooted at ``startpath``, one entry per line.

    Each directory is logged as ``<name>/`` and each file by its name, both
    indented by four spaces per nesting level below ``startpath``.

    The nesting level is derived from the path of each directory relative to
    ``startpath``. Stripping ``startpath`` as a plain substring gave wrong
    indentation when ``startpath`` ended with a path separator or when the
    same text occurred again deeper in the tree.

    Args:
        startpath: Directory whose contents are logged.
    """
    for root, _dirs, files in os.walk(startpath):
        relative = os.path.relpath(root, startpath)
        # relpath yields os.curdir for startpath itself; every other
        # directory is one level deeper than its separator count.
        level = 0 if relative == os.curdir else relative.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath drops a trailing separator so basename is never empty
        # for a startpath such as "models/".
        name = os.path.basename(os.path.normpath(root))
        logger.info(f"{indent}{name}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")

async def download_hf_file(filename: str) -> None:
    """Download one file from the Hugging Face repo into the models directory.

    The destination is ``MODELS_DIR/filename``; its parent directory is
    created if needed. When the destination already exists nothing is
    downloaded. Otherwise the blocking ``hf_hub_download`` call runs in a
    worker thread through ``asyncio.to_thread`` so the event loop stays free.

    Args:
        filename: Path of the file inside the ``HF_REPO_ID`` repository.

    Raises:
        Exception: Any error raised by the download is logged and re-raised
            after a destination file left on disk, if any, has been removed.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    if os.path.exists(dest):
        # this is really for debugging purposes only
        logger.debug(f"    ✅ {filename}")
        return

    logger.info(f"    ⏳ Downloading {HF_REPO_ID}/{filename}")

    try:
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
        logger.info(f"    ✅ Downloaded {filename}")
    except Exception as e:
        logger.error(f"🚨 Error downloading file from Hugging Face: {e}")
        # Remove whatever is at the destination so the next run does not
        # mistake it for a completed download.
        if os.path.exists(dest):
            os.remove(dest)
        raise

async def download_all_models():
    """Make sure every file listed in ``MODEL_FILES`` is available locally.

    One ``download_hf_file`` coroutine is created per entry and all of them
    are awaited together with ``asyncio.gather``, so an exception raised by
    any download propagates to the caller.
    """
    logger.info("  🔎 Looking for models...")
    tasks = [download_hf_file(filename) for filename in MODEL_FILES]
    await asyncio.gather(*tasks)
    logger.info("  ✅ All models are available")

    # Debugging aid: to verify that the models were downloaded properly,
    # un-comment the two following lines:
    #logger.info("💡 Printing directory structure of models:")
    #print_directory_structure(MODELS_DIR)

class ModelLoader:
    """Load and initialize the models required by the application."""

    def __init__(self):
        # Device and models directory taken from the module configuration
        self.device = DEVICE
        self.models_dir = MODELS_DIR

    async def load_live_portrait(self):
        """Build the LivePortrait pipeline without blocking the event loop.

        The ``liveportrait`` modules are imported when this method runs
        rather than at module import time, and the pipeline constructor is
        executed in a worker thread via ``asyncio.to_thread``.

        Returns:
            The constructed ``LivePortraitPipeline`` instance.
        """
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info("  ⏳ Loading LivePortrait models...")
        live_portrait_pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=InferenceConfig(
                # default values
                flag_stitching=True,  # we recommend setting it to True!
                flag_relative=True,  # whether to use relative motion
                flag_pasteback=True,  # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
                flag_do_crop=True,  # whether to crop the source portrait to the face-cropping space
                flag_do_rot=True,  # whether to conduct the rotation when flag_do_crop is True
            ),
            crop_cfg=CropConfig()
        )
        logger.info("  ✅ LivePortrait models loaded successfully.")
        return live_portrait_pipeline

async def initialize_models():
    """Download the required model files, then load the LivePortrait pipeline.

    Returns:
        The object returned by ``ModelLoader.load_live_portrait``.

    Raises:
        Exception: Whatever ``download_all_models`` or the pipeline loading
            raises is propagated unchanged.
    """
    logger.info("🚀 Starting model initialization...")

    # Ensure all required models are downloaded
    await download_all_models()

    # Initialize the ModelLoader
    loader = ModelLoader()

    # Load LivePortrait models
    live_portrait = await loader.load_live_portrait()

    logger.info("✅ Model initialization completed.")
    return live_portrait

# Initial setup: executed when the module is imported, so the models
# directory is created (or confirmed to exist) at that point
logger.info("🚀 Setting up storage directories...")
create_directory(MODELS_DIR)
logger.info("✅ Storage directories setup completed.")