import base64
import datetime
import gradio as gr
import numpy as np
import os
import pytz
import psutil
import re
import random
import torch
import time
import shutil  # Added for zip functionality
import zipfile
from PIL import Image
from io import BytesIO
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny

try:
    # Optional Intel GPU support. `ipex` is not referenced anywhere else in the
    # visible code, so it is presumably imported for its side effects on torch.
    import intel_extension_for_pytorch as ipex
except Exception:
    # Best-effort: the extension is optional, and a broken or incompatible
    # install may fail with errors other than ImportError. Unlike a bare
    # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
    pass

# Runtime switches read from the environment (None when unset).
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # not referenced in the visible code
# MPS is the Apple-silicon (macOS) backend; hasattr guards older torch builds.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
# Preference order: CUDA, then Intel XPU, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# Half precision only on accelerators: float16 inference on CPU is unsupported
# or very slow for many ops, so CPU-only machines use float32.
torch_dtype = torch.float16 if device.type in ("cuda", "xpu") else torch.float32

# Function to encode a file to base64
def encode_file_to_base64(file_path):
    """Read *file_path* in binary mode and return its bytes as a base64 str."""
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()

def create_zip_of_files(files, zip_name="all_files.zip"):
    """
    Create a zip archive containing *files*.

    Args:
        files: iterable of file paths to add. Each member is stored under the
            path it was given (``zipfile`` strips a drive letter and leading
            separators from member names).
        zip_name: path of the archive to create. An existing file at that
            path is overwritten. Defaults to ``"all_files.zip"`` in the
            current working directory.

    Returns:
        ``zip_name``.
    """
    with zipfile.ZipFile(zip_name, "w") as zipf:
        for file_path in files:
            zipf.write(file_path)
    return zip_name


def get_zip_download_link(zip_file):
    """
    Build an HTML ``<a>`` tag that downloads *zip_file* through a data URI.

    The whole archive is read into memory and base64-encoded into ``href``.
    The ``download`` attribute only suggests a save name, so it carries the
    file's base name rather than its full path. That name is HTML-escaped so
    quotes or ``&`` in it cannot break the markup.

    Returns:
        The anchor markup as a string.
    """
    import html

    with open(zip_file, 'rb') as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    save_name = html.escape(os.path.basename(zip_file), quote=True)
    return f'<a href="data:application/zip;base64,{b64}" download="{save_name}">Download All</a>'

# Function to clear all image files
def clear_all_images():
    """
    Delete every image file in the current working directory.

    Matches ``.png``, ``.jpg`` and ``.jpeg`` names (case-insensitive) and
    prints each removed name. Directories whose names merely end in an image
    extension are skipped. Otherwise ``os.remove`` would raise on them and
    abort the rest of the clean-up.
    """
    base_dir = os.getcwd()
    img_files = [
        name
        for name in os.listdir(base_dir)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
        and os.path.isfile(os.path.join(base_dir, name))
    ]

    for name in img_files:
        os.remove(os.path.join(base_dir, name))
        print('removed:' + name)
  
# File save, download, and clear helpers:
# Function to create a zip file from a list of files
def create_zip(files):
    """
    Bundle *files* into a timestamped archive in the current directory.

    Each member is stored under its base name only (directory components are
    dropped), and every added path is logged.

    Returns:
        The archive name, ``images_<YYYYmmddHHMMSS>.zip``.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    archive_name = f"images_{stamp}.zip"
    print(f"Creating file {archive_name}")
    with zipfile.ZipFile(archive_name, "w") as archive:
        for path in files:
            archive.write(path, arcname=os.path.basename(path))
            print("added:" + path)
    return archive_name
    
# Function to save all images as a zip file and provide a base64 download link
def save_all_images(images):
    """
    Zip the given image files and build an HTML download link for the archive.

    Args:
        images: list of image file paths. ``None`` or an empty list means
            there is nothing to save.

    Returns:
        ``(zip_filename, download_link)``. ``download_link`` is an ``<a>`` tag
        that embeds the archive as a base64 data URI. Returns
        ``(None, None)`` when there are no images.
    """
    if not images:
        return None, None
    zip_filename = create_zip(images)  # Create a zip file from the list of image files
    print(' Zip file created:' + zip_filename)

    # Only the return values can reach the UI. Gradio renders components
    # created inside a `gr.Blocks` context, not ones constructed while an
    # event handler runs, so none are built here.
    zip_base64 = encode_file_to_base64(zip_filename)
    download_link = (
        f'<a href="data:application/zip;base64,{zip_base64}" '
        f'download="{zip_filename}">Download All</a>'
    )
    return zip_filename, download_link
        
# Function to handle "Save All" button click
def save_all_button_click():
    """
    "Save All" click handler: zip every image in the working directory.

    Collects ``.png``/``.jpg``/``.jpeg`` files (case-insensitive) from the
    current directory and passes them to ``save_all_images``. The resulting
    archive name and download link are only printed. This handler returns
    nothing for the UI to display, and components constructed inside an event
    handler are not added to the page, so none are created here.
    """
    images = [
        name
        for name in os.listdir()
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    zip_filename, download_link = save_all_images(images)
    if zip_filename:
        print(zip_filename)
    if download_link:
        print(download_link)

# Function to handle "Clear All" button click
def clear_all_button_click():
    """Handle the "Clear All" button by delegating to ``clear_all_images``."""
    return clear_all_images()

# Apple-silicon override: load/cast the weights on the CPU in float32, then
# move the pipeline to MPS below. This runs before the diagnostics so the
# printed device is the one actually used.
if mps_available:
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")

# The safety checker is kept only when SAFETY_CHECKER is exactly "True".
if SAFETY_CHECKER == "True":
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7")
else:
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", safety_checker=None)

pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
# Cast on torch_device first (the CPU on MPS machines), then move to `device`.
pipe.to(device=torch_device, dtype=torch_dtype).to(device)
pipe.unet.to(memory_format=torch.channels_last)
pipe.set_progress_bar_config(disable=True)

# Enable attention slicing (lower peak memory) on machines with < 64 GiB RAM.
if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()

# Load and fuse the LCM LoRA *before* compiling. The compile warm-up below
# must then trace the UNet that will actually serve requests. Modifying the
# UNet after the warm-up would invalidate it.
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.fuse_lora()

# NOTE(review): any non-empty TORCH_COMPILE value (even "False" or "0")
# enables compilation; only an unset/empty variable disables it.
if TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
    # One-step warm-up so compilation happens at startup, not on the first request.
    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)

def safe_filename(text):
    """
    Build a ``<slug>_<YYYYMMDD>.png`` filename from *text*.

    Every run of non-word characters in *text* collapses to one underscore.
    """
    slug = re.sub(r'\W+', '_', text)
    return f"{slug}_{datetime.datetime.now():%Y%m%d}.png"
    
def encode_image(image):
    """
    Encode *image* as PNG and return the bytes as a base64 str.

    Args:
        image: any object with a PIL-style ``save(fp, format=...)`` method,
            e.g. a ``PIL.Image.Image``.
    """
    buffered = BytesIO()
    # Without this save the buffer stays empty and the result is always "".
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def fake_gan():
    """
    Collect the image files in the working directory for the gallery.

    Returns:
        A list of ``(filename, caption)`` tuples in random order, one per
        ``.png``/``.jpg``/``.jpeg`` file (case-insensitive). The caption is
        the filename without its extension. The list is empty when there are
        no images.
    """
    base_dir = os.getcwd()
    img_files = [
        name
        for name in os.listdir(base_dir)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    # Shuffle for a random presentation, but keep each image paired with its
    # own caption. Drawing with random.choice() per slot would mislabel
    # images and repeat some files while dropping others.
    random.shuffle(img_files)
    return [(name, os.path.splitext(name)[0]) for name in img_files]
    
def predict(prompt, guidance, steps, seed=1231231):
    """
    Generate one 512x512 image for *prompt* and save a copy to disk.

    Args:
        prompt: text prompt passed to the pipeline.
        guidance: guidance scale passed to the pipeline.
        steps: number of inference steps.
        seed: seed for ``torch.manual_seed`` so results are reproducible.

    Returns:
        The first generated PIL image, or ``None`` if the pipeline produced
        no images. Saving is best-effort: if writing the file fails, the
        error is logged and the image is still returned.
    """
    generator = torch.manual_seed(seed)
    start = time.time()
    results = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=512,
        height=512,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - start} seconds")

    # The key may be absent (e.g. when the pipeline has no safety checker).
    if "nsfw_content_detected" in results and results.nsfw_content_detected[0]:
        print("NSFW content detected")

    if not results.images:
        return None

    image = results.images[0]
    filename = _image_filename(prompt)
    try:
        image.save(filename)
        print(f"#Image saved as {filename}")
    except Exception as err:  # best-effort save; never fail the request over it
        print(f"Could not save image as {filename}: {err}")
    return image


def _image_filename(prompt):
    """Return ``<YYYYMMDD>_<slug>.png`` for *prompt* (``None`` acts as "").

    Spaces and newlines become underscores. All other characters that are
    not alphanumeric or ``_`` are dropped. The slug is capped at 90 characters.
    """
    safe_date_time = datetime.datetime.now().strftime("%Y%m%d")
    replaced_prompt = (prompt or "").replace(" ", "_").replace("\n", "_")
    safe_prompt = "".join(ch for ch in replaced_prompt if ch.isalnum() or ch == "_")[:90]
    return f"{safe_date_time}_{safe_prompt}.png"


# Layout tweaks: center the main column (capped at 40rem) and the intro text.
css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="container"):
        gr.Markdown(
            """4📝RT🖼️Images - 🕹️ Real Time 🎨 Image Generator Gallery 🌐""",
            elem_id="intro",
        )
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(
                    placeholder="Insert your prompt here:", scale=5, container=False
                )
                generate_bt = gr.Button("Generate", scale=1)

        # Image result from the last prompt (predict returns a PIL image).
        image = gr.Image(type="filepath")

        # Gallery of saved images: `btn` fills it with (path, caption) pairs
        # from fake_gan(). The textbox beside it is not wired to any event.
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Image Sets",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            )
            btn = gr.Button("Generate Gallery of Saved Images")
        gallery = gr.Gallery(
            label="Generated Images", show_label=False, elem_id="gallery"
        )

        with gr.Row(variant="compact"):
            # Add "Save All" button with emoji
            save_all_button = gr.Button("💾 Save All", scale=1)
            # Add "Clear All" button with emoji
            clear_all_button = gr.Button("🗑️ Clear All", scale=1)

        # Advanced generation options
        with gr.Accordion("Advanced options", open=False):
            guidance = gr.Slider(
                label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
            )
            steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
            seed = gr.Slider(
                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
            )

        # Usage notes for running the same model directly with diffusers
        with gr.Accordion("Run with diffusers"):
            gr.Markdown(
                """## Running LCM-LoRAs with `diffusers`
            ```bash
            pip install diffusers==0.23.0
            ```
            
            ```py
            from diffusers import DiffusionPipeline, LCMScheduler
            pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda") 
            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
            pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA
            results = pipe(
                prompt="ImageEditor",
                num_inference_steps=4,
                guidance_scale=0.0,
            )
            results.images[0]
            ```
            """
            )

        # Event wiring. For real-time generation, every prompt keystroke and
        # every slider change re-runs predict, as does the Generate button.
        inputs = [prompt, guidance, steps, seed]
        generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        btn.click(fake_gan, None, gallery)
        prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        guidance.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
        seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)

        # Save/Clear handlers are registered without outputs.
        save_all_button.click(save_all_button_click)

        # Standalone file component; no event in this layout reads or writes it.
        # (The former `input = file_obj` alias shadowed the builtin and was unused.)
        with gr.Column():
            file_obj = gr.File(label="Input File")

        clear_all_button.click(clear_all_button_click)

# Enable Gradio's request queue, then start the server.
demo.queue()
demo.launch()