# This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_seg2image.py
# The original license file is LICENSE.ControlNet in this repo.
import gradio as gr
from PIL import Image

# NOTE: example structure of a gradio Gallery output item (from debugging):
#   {'name': '/tmp/tmpw60bbw6k.png', 'data': 'file=/tmp/tmpw60bbw6k.png', 'is_file': True}
#   {'name': '/tmp/tmpba0d5dt5.png', 'data': 'file=/tmp/tmpba0d5dt5.png', 'is_file': True}

import base64
import io

import numpy as np

def encode(img_array):
    """Encode a NumPy image array as a Base64 JPEG string.

    Parameters
    ----------
    img_array : numpy.ndarray
        Image data accepted by ``PIL.Image.fromarray`` (e.g. uint8 HxWx3 RGB).

    Returns
    -------
    str
        Base64 (UTF-8 text) encoding of the image serialized as JPEG.
    """
    print(f"type of input_image ^^ - {type(img_array)}")
    img = Image.fromarray(img_array)

    # Serialize to JPEG entirely in memory. The original implementation wrote
    # "temp_image.jpeg" to the working directory and re-read it, which leaked
    # a stray file and was race-prone under concurrent requests.
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")

    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def create_imgcomp(input_image, result_image): #(input_image, filename):
    """Return a complete HTML document showing *input_image* and
    *result_image* in a drag-to-compare slider.

    Both images are embedded inline as Base64 JPEG data URIs (via
    ``encode``), so the returned page is self-contained and needs no
    file serving. Intended to be rendered in a ``gr.HTML`` component.
    """
    # Inline both images so the HTML needs no external file references.
    encoded_string_in = encode(input_image)
    encoded_string_out = encode(result_image)
    
    # The first <img> sits inside the resizable overlay div; the closing
    # </div> between the two tags is matched by the opening divs in `desc`.
    htmltag = '<img src= "data:image/jpeg;base64,' + encoded_string_in + '" alt="Original Image" height="500"/></div> <img src= "data:image/jpeg;base64,' + encoded_string_out + '" alt="Control Net Image" height="500"/>'
    #sample - htmltag = '<img src= "data:image/jpeg;base64,' + encoded_string + '" alt="Original Image"/></div> <img src= "https://ysharma-controlnet-image-comparison.hf.space/file=' + filename + '" alt="Control Net Image"/>'
    print(f"htmltag is ^^ - {htmltag}")
    # HTML/CSS shell: `.image-slider > div` is an absolutely-positioned,
    # horizontally-resizable overlay; dragging its corner handle (drawn by
    # the :before rule) reveals more/less of the first image over the second.
    desc = """
        <!DOCTYPE html>
        <html lang="en">
        <head>
        	<style>
        		body {
        			background: rgb(17, 17, 17);
        		}
        		
        		.image-slider {
        			margin-left: 3rem;
        			position: relative;
        			display: inline-block;
        			line-height: 0;
        		}
        		
        		.image-slider img {
        			user-select: none;
        			max-width: 400px;
        		}
        		
        		.image-slider > div {
        			position: absolute;
        			width: 25px;
        			max-width: 100%;
        			overflow: hidden;
        			resize: horizontal;
        		}
        		
        		.image-slider > div:before {
        			content: '';
        			display: block;
        			width: 13px;
        			height: 13px;
        			overflow: hidden;
        			position: absolute;
        			resize: horizontal;
        			right: 3px;
        			bottom: 3px;
        			background-clip: content-box;
        			background: linear-gradient(-45deg, black 50%, transparent 0);
        			-webkit-filter: drop-shadow(0 0 2px black);
        			filter: drop-shadow(0 0 2px black);
        		}
        	</style>
        </head>
        <body>
        	<div style="margin: 3rem;
        				font-family: Roboto, sans-serif">
        		</div> <div> <div class="image-slider"> <div> """ + htmltag + "</div> </div> </body> </html> "
    return desc



def dummyfun(result_gallery):
    """Debug helper: log shape/contents of a gallery result and return the
    file path (``'name'`` key) of its second item.

    Expects ``result_gallery`` to be an indexable sequence of at least two
    dict-like items shaped like ``{'name': ..., 'data': ..., 'is_file': ...}``.
    """
    print(f"type of gallery is ^^ - {type(result_gallery)}")
    print(f"length of gallery is ^^ - {len(result_gallery)}")
    print(f"first elem of gallery is ^^ - {result_gallery[0]}")
    print(f"first elem of gallery is ^^ - {result_gallery[1]}")

    second_item = result_gallery[1]
    return second_item['name']

def create_demo(process, max_images=12):
    """Build the Gradio Blocks UI for segmentation-map-guided generation.

    Parameters
    ----------
    process : callable
        Inference function wired to the Run button. It receives the values
        of the components listed in ``ips`` (in that order) and must return
        outputs matching ``[result_image, msg]``.
    max_images : int, optional
        Upper bound of the (hidden) 'Images' slider. Defaults to 12.

    Returns
    -------
    gr.Blocks
        The assembled demo, ready for ``.launch()``.
    """
    with gr.Blocks(css="#input_image {width: 512px;} #out_image {width: 512px;}") as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Segmentation Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy', elem_id='input_image')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                # Advanced options stay hidden; values below are still passed
                # to `process` as defaults.
                with gr.Accordion('Advanced options', open=False, visible=False):
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=1,
                                            step=1)
                    image_resolution = gr.Slider(label='Image Resolution',
                                                 minimum=256,
                                                 maximum=768,
                                                 value=512,
                                                 step=256)
                    detect_resolution = gr.Slider(
                        label='Segmentation Resolution',
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=1)
                    ddim_steps = gr.Slider(label='Steps',
                                           minimum=1,
                                           maximum=100,
                                           value=20,
                                           step=1)
                    scale = gr.Slider(label='Guidance Scale',
                                      minimum=0.1,
                                      maximum=30.0,
                                      value=9.0,
                                      step=0.1)
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True,
                                     queue=False)
                    eta = gr.Number(label='eta (DDIM)', value=0.0)
                    a_prompt = gr.Textbox(
                        label='Added Prompt',
                        value='best quality, extremely detailed')
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                # Hidden until the first run. Fix: `type` is a gr.Image
                # constructor argument, not a `.style()` option — it was
                # previously passed to .style(), where it has no effect.
                result_image = gr.Image(type='numpy', visible=False).style(height='auto')
                with gr.Box():
                    msg = gr.HTML()
                    imagecomp = gr.HTML()
        # Input components forwarded to `process`, in positional order.
        ips = [
            input_image, prompt, a_prompt, n_prompt, num_samples,
            image_resolution, detect_resolution, ddim_steps, scale, seed, eta
        ]
        run_button.click(fn=process,
                         inputs=ips,
                         outputs=[result_image, msg],
                         api_name='seg')
        # Regenerate the HTML comparison whenever a new result arrives.
        result_image.change(create_imgcomp, [input_image, result_image], [imagecomp])

    return demo