import os
import numpy as np
from prodiapy import Prodia
import gradio as gr
import json
import requests
import base64
import random
import time 

# Choices for the "Style template" dropdown. The leading None means "no preset":
# generate_image only sends style_preset to the API when it is not None.
STYLE_PRESETS = [None] + [
    "3d-model", "analog-film", "anime", "cinematic", "comic-book", "digital-art",
    "enhance", "fantasy-art", "isometric", "line-art", "low-poly", "neon-punk",
    "origami", "photographic", "pixel-art", "texture", "craft-clay",
]
# Upper bound for the seed slider and for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max


class Prodia:
    """Minimal client for the Prodia v1 REST API (PhotoMaker jobs only).

    NOTE: this local class shadows the ``Prodia`` imported from ``prodiapy``
    at the top of the file.
    """

    def __init__(self, api_key=None, base=None, request_timeout=60):
        """Create a client.

        Args:
            api_key: Prodia API key sent as the ``X-Prodia-Key`` header. When
                omitted, the ``PRODIA_API_KEY`` environment variable is read.
            base: API base URL; defaults to ``https://api.prodia.com/v1``.
            request_timeout: Seconds ``requests`` waits for each HTTP call
                before giving up, so a stalled connection cannot hang forever.
        """
        # Resolve the env var when the client is built, not once at class-definition time.
        if api_key is None:
            api_key = os.getenv("PRODIA_API_KEY")
        self.base = base or "https://api.prodia.com/v1"
        self.request_timeout = request_timeout
        self.headers = {
            "X-Prodia-Key": api_key
        }

    def photomaker(self, params):
        """Submit a PhotoMaker job and return the API's JSON response (the job)."""
        # Log the request without the base64 image payload, which can be megabytes long.
        summary = {key: value for key, value in params.items() if key != "imageData"}
        if "imageData" in params:
            summary["imageData"] = f"<{len(params['imageData'])} image(s)>"
        print(summary)
        response = self._post(f"{self.base}/photomaker", params)
        return response.json()

    def get_job(self, job_id):
        """Fetch the current JSON state of job ``job_id``."""
        response = self._get(f"{self.base}/job/{job_id}")
        return response.json()

    def wait(self, job, timeout=None, interval=0.25):
        """Poll until ``job`` reaches ``succeeded`` or ``failed``; return the final job JSON.

        Args:
            job: Job dict as returned by ``photomaker``; must contain ``status``
                and ``job`` (the job id used for polling).
            timeout: Maximum seconds to keep polling. ``None`` (the default)
                polls until the job finishes, as before.
            interval: Seconds to sleep between polls.

        Raises:
            TimeoutError: If ``timeout`` elapses before the job finishes.
        """
        job_result = job
        deadline = None if timeout is None else time.monotonic() + timeout

        while job_result['status'] not in ('succeeded', 'failed'):
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"Prodia job {job['job']} did not finish within {timeout}s")
            time.sleep(interval)
            job_result = self.get_job(job['job'])

        return job_result

    def _post(self, url, params):
        """POST ``params`` as a JSON body to ``url``; raise on any non-200 status."""
        headers = {
            **self.headers,
            "Content-Type": "application/json"
        }
        response = requests.post(url, headers=headers, data=json.dumps(params),
                                 timeout=self.request_timeout)
        self._check_status(response)
        return response

    def _get(self, url):
        """GET ``url`` with the auth headers; raise on any non-200 status."""
        response = requests.get(url, headers=self.headers, timeout=self.request_timeout)
        self._check_status(response)
        return response

    @staticmethod
    def _check_status(response):
        """Raise ``Exception`` unless ``response`` has HTTP status 200."""
        if response.status_code != 200:
            raise Exception(f"Bad Prodia Response: {response.status_code}")

    
# Module-wide API client used by generate_image; the key comes from PRODIA_API_KEY.
client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))


def generate_image(upload_images, prompt, negative_prompt, style_preset, steps, cfg_scale, strength, seed, progress=gr.Progress(track_tqdm=True)):
    """Run a PhotoMaker job on Prodia and return the URL of the generated image.

    Args:
        upload_images: Uploaded face photos; each item is read by
            ``file_to_base64`` (presumably a file path — confirm for the
            installed Gradio version).
        prompt: Text prompt; must contain the trigger word ``img``.
        negative_prompt: Negative prompt forwarded to the API.
        style_preset: Entry of ``STYLE_PRESETS``; ``None`` or an unknown value
            sends no preset.
        steps: Number of sampling steps (sent as an int).
        cfg_scale: Guidance scale, forwarded unchanged.
        strength: Strength slider value, forwarded unchanged.
        seed: Seed; ``0`` picks a random seed in ``[1, MAX_SEED]``.
        progress: Progress tracker supplied by Gradio; not used directly.

    Returns:
        The ``imageUrl`` of the finished job.

    Raises:
        gr.Error: If the prompt lacks ``img``, no photos were uploaded, or the
            job failed.
    """
    error_if_no_img(prompt)
    # Without this check a missing upload crashes with "NoneType is not iterable".
    if not upload_images:
        raise gr.Error("Please upload at least one photo of your face")

    params = {
        "imageData": [file_to_base64(img) for img in upload_images],
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        # Sliders may deliver floats (e.g. 40.0); the step count must be an integer.
        "steps": int(steps),
        "cfg_scale": cfg_scale,
        "strength": strength,
        # 0 is the UI's "random seed" sentinel.
        "seed": int(seed) if seed else random.randint(1, MAX_SEED),
    }

    if style_preset is not None and style_preset in STYLE_PRESETS:
        params['style_preset'] = style_preset

    job = client.photomaker(params)
    res = client.wait(job)

    if res['status'] == "failed":
        # Surface the failure instead of silently returning an empty result.
        raise gr.Error("Image generation failed, please try again")

    return res['imageUrl']


def error_if_no_img(prompt):
    """Reject prompts that lack the PhotoMaker trigger word.

    Args:
        prompt: The user's text prompt.

    Raises:
        gr.Error: If the substring ``img`` does not occur anywhere in ``prompt``.
    """
    trigger_present = "img" in prompt
    if trigger_present:
        return
    raise gr.Error("Prompt must contain 'img'")


def swap_to_gallery(images):
    """Show uploaded images in the gallery and hide the upload widget.

    Returns three Gradio updates, in order: the gallery (shown, filled with
    ``images``), the clear-button column (shown), and the file widget (hidden).
    """
    gallery_update = gr.update(value=images, visible=True)
    clear_update = gr.update(visible=True)
    upload_update = gr.update(visible=False)
    return gallery_update, clear_update, upload_update


def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Show a clicked example's images in the gallery.

    Only ``images`` is used; ``prompt``, ``style`` and ``negative_prompt``
    exist because the examples table passes all four input columns to this
    function. Produces exactly the same updates as ``swap_to_gallery``, so it
    delegates instead of duplicating that logic.
    """
    return swap_to_gallery(images)


def remove_back_to_files():
    """Hide the gallery and clear-button column and show the upload widget again.

    Returns three Gradio visibility updates: gallery, clear-button column,
    file widget.
    """
    return tuple(gr.update(visible=flag) for flag in (False, False, True))


def get_image_path_list(folder_name):
    """Return the sorted paths of every entry in ``folder_name``.

    Paths are joined with ``os.path.join`` so they keep ``folder_name`` exactly
    as given (e.g. a leading ``./``). No filtering is applied: hidden files and
    subdirectories are included too.
    """
    return sorted(os.path.join(folder_name, entry) for entry in os.listdir(folder_name))


def file_to_base64(file_path):
    """Read ``file_path`` as raw bytes and return them Base64-encoded as a ``str``."""
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')


def get_example():
    """Return the rows shown in the examples table.

    Each row is ``[image paths, prompt, style preset, negative prompt]``,
    matching the examples' input order (files, prompt, style, negative prompt).
    """
    shared_negative = "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth"

    woman_row = [
        get_image_path_list('./examples/scarletthead_woman'),
        "instagram photo, portrait photo of a woman img , colorful, perfect face, natural skin, hard shadows, film grain",
        None,
        shared_negative,
    ]
    man_row = [
        get_image_path_list('./examples/newton_man'),
        "sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain",
        None,
        shared_negative,
    ]
    return [woman_row, man_row]


# HTML page header, rendered with gr.Markdown at the top of the app.
title = r"""
<h1 align="center">PhotoMaker: Generate images with facial consistency to input images</h1>
"""

# Custom CSS passed to gr.Blocks: widen the main Gradio container to 85%.
css = '''
.gradio-container {width: 85% !important}
'''
# Build the UI: inputs (photos, prompt, style, advanced options) in the left
# column, the generated image in the right column, plus an examples table.
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)

    with gr.Row():
        with gr.Column():
            # Upload widget; swapped for the gallery below after an upload.
            files = gr.File(
                label="Drag (Select) 1 or more photos of your face",
                file_types=["image"],
                file_count="multiple"
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
            # Hidden until files are uploaded; the button clears `files`.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
                                placeholder="A photo of a [man/woman img]...")
            style = gr.Dropdown(label="Style template", choices=STYLE_PRESETS, value=None)
            submit = gr.Button("Submit")

            with gr.Accordion(open=False, label="Advanced Options"):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt", 
                    placeholder="low quality",
                    value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
                )
                with gr.Row():
                    steps = gr.Slider(
                        label="Number of sample steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=40,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG Scale",
                        minimum=5,
                        maximum=20,
                        value=7,
                    )
                with gr.Row():
                    # Passed to generate_image as `strength`.
                    strength_ratio = gr.Slider(
                        label="Strength (%)",
                        minimum=15,
                        maximum=50,
                        step=1,
                        value=20,
                    )
                    # 0 tells generate_image to pick a random seed.
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
        with gr.Column():
            result_image = gr.Image(label="Generated Image")

        # Toggle between the upload widget and the gallery of uploaded images.
        files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
        remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])

        # Input order must match generate_image's positional parameters.
        submit.click(
            fn=generate_image,
            inputs=[files, prompt, negative_prompt, style, steps, cfg_scale, strength_ratio, seed],
            outputs=[result_image]
        )

    # Clicking an example fills the four inputs and, because run_on_click=True,
    # runs upload_example_to_gallery to show its photos in the gallery.
    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, style, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )
    
# Start the server only when this file is run as a script, so importing the
# module (tests, tooling) does not launch Gradio as a side effect.
# At most 20 requests wait in the queue; the API page is hidden.
if __name__ == "__main__":
    demo.queue(max_size=20).launch(show_api=False)