import io
import os
import uuid
from datetime import datetime

import requests
import streamlit as st
from openai import OpenAI
from PIL import Image

def convert_to_png(image_file, max_bytes=4 * 1024 * 1024):
    """Convert an uploaded image to an RGBA PNG no larger than ``max_bytes``.

    The image is re-encoded as PNG. If the encoded result exceeds
    ``max_bytes`` it is scaled down in 10% steps (aspect ratio preserved)
    until it fits.

    Args:
        image_file: A path or binary file-like object accepted by
            ``PIL.Image.open``.
        max_bytes: Maximum size of the encoded PNG in bytes (default 4MB).

    Returns:
        An ``io.BytesIO`` holding the PNG data, positioned at offset 0.
    """
    def encode(img):
        # Serialise ``img`` as an optimised PNG into a fresh, rewound buffer.
        buffer = io.BytesIO()
        img.save(buffer, format='PNG', optimize=True)
        buffer.seek(0)
        return buffer

    image = Image.open(image_file)

    # Modes that carry transparency must be converted straight to RGBA;
    # routing them through RGB first would throw the alpha channel away,
    # which destroys the transparent region of a mask image.
    keeps_alpha = (
        image.mode in ('RGB', 'RGBA', 'LA', 'PA', 'RGBa')
        or (image.mode == 'P' and 'transparency' in image.info)
    )
    if not keeps_alpha:
        image = image.convert('RGB')
    image = image.convert('RGBA')

    byte_arr = encode(image)

    # Always scale from the full-resolution source so that repeated passes
    # do not compound resampling blur.
    source = image
    source_width, source_height = source.size
    scale = 1.0
    while byte_arr.getbuffer().nbytes > max_bytes:
        scale *= 0.9  # Reduce by 10% per pass
        new_size = (
            max(1, int(source_width * scale)),  # never collapse to 0 pixels
            max(1, int(source_height * scale)),
        )
        if new_size == (1, 1) and image.size == (1, 1):
            # Cannot shrink any further; return the smallest encoding we have.
            break
        image = source.resize(new_size, Image.Resampling.LANCZOS)
        byte_arr = encode(image)

    return byte_arr

def validate_image(uploaded_file):
    """Check that an uploaded file does not exceed the 4MB limit.

    Reads only the ``size`` attribute (in bytes) of ``uploaded_file``.

    Returns:
        A ``(is_valid, message)`` tuple: ``(True, "OK")`` when the size is
        at most 4MB, otherwise ``(False, <reason>)``.
    """
    limit_bytes = 4 * 1024 * 1024  # 4MB in bytes
    fits = uploaded_file.size <= limit_bytes
    if fits:
        return True, "OK"
    return False, "File size must be less than 4MB"

def save_uploaded_file(uploaded_file, folder="uploads"):
    """Convert an uploaded image to PNG and save it under a unique name.

    Args:
        uploaded_file: The uploaded image; passed to ``convert_to_png``.
        folder: Directory to write into; created if it does not exist.

    Returns:
        The path of the PNG file that was written.
    """
    os.makedirs(folder, exist_ok=True)

    # Convert to PNG (and shrink if necessary) before touching the disk
    image_bytes = convert_to_png(uploaded_file)

    # A timestamp alone only has one-second resolution, so two uploads saved
    # back to back (original + mask) would get the same name and the second
    # would overwrite the first. The random suffix keeps every file distinct.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join(folder, f"{timestamp}_{uuid.uuid4().hex}.png")

    with open(file_path, "wb") as f:
        f.write(image_bytes.getvalue())

    return file_path

def download_image(url, folder="generated", timeout=30):
    """Download an image from ``url`` and save it as a PNG file.

    Args:
        url: Address of the image to fetch.
        folder: Directory to write into; created if it does not exist.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang the app indefinitely.

    Returns:
        The path of the saved file, or ``None`` if the server did not
        respond with HTTP 200.

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    os.makedirs(folder, exist_ok=True)

    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return None

    # The timestamp has one-second resolution; the random suffix stops
    # several images downloaded within the same second from overwriting
    # one another.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join(
        folder, f"generated_image_{timestamp}_{uuid.uuid4().hex}.png"
    )

    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path

def main():
    """Render the Streamlit UI and run a DALL-E 2 image edit on demand.

    The sidebar collects the OpenAI API key, output size and image count.
    The main area takes an original image, a mask image and a prompt. When
    the button is pressed, both uploads are converted and saved to temporary
    PNG files, sent to ``client.images.edit``, and each returned image is
    downloaded, displayed and offered via a download button. The temporary
    upload files are removed afterwards whether or not the request succeeded.
    """
    st.title("🎨 Image Editor with DALL-E 2")
    
    # Sidebar for API key
    with st.sidebar:
        st.header("Configuration")
        api_key = st.text_input("Enter your OpenAI API key", type="password")
        st.markdown("""
        ### How to get an API key
        1. Go to [OpenAI API Keys](https://platform.openai.com/api-keys)
        2. Create a new secret key
        3. Copy and paste it here
        """)
        
        # Size selection
        size_option = st.selectbox(
            "Select image size:",
            ["1024x1024", "512x512", "256x256"]
        )
        
        # Number of images
        num_images = st.slider("Number of images to generate", 1, 4, 1)

    # Main content
    st.markdown("""
    ### Requirements:
    - Original image and mask must be less than 4MB
    - Images will be automatically converted to PNG format
    - Images larger than 4MB will be automatically resized
    """)

    # File uploaders
    col1, col2 = st.columns(2)
    
    with col1:
        st.subheader("Original Image")
        original_image = st.file_uploader(
            "Upload original image", 
            type=["png", "jpg", "jpeg"]
        )
        if original_image:
            # Oversized uploads are only warned about; they are shrunk later
            valid, message = validate_image(original_image)
            if not valid:
                st.warning(f"Original image: {message}. Image will be automatically resized.")
            
            try:
                # Display image preview
                image = Image.open(original_image)
                st.image(image, caption="Original Image", use_container_width=True)
                st.caption(f"Original size: {original_image.size/1024/1024:.2f}MB")
            except Exception as e:
                st.error(f"Error loading image: {str(e)}")

    with col2:
        st.subheader("Mask Image")
        mask_image = st.file_uploader(
            "Upload mask image", 
            type=["png", "jpg", "jpeg"]
        )
        if mask_image:
            # Oversized uploads are only warned about; they are shrunk later
            valid, message = validate_image(mask_image)
            if not valid:
                st.warning(f"Mask image: {message}. Image will be automatically resized.")
            
            try:
                # Display image preview (same sizing option as the original)
                image = Image.open(mask_image)
                st.image(image, caption="Mask Image", use_container_width=True)
                st.caption(f"Original size: {mask_image.size/1024/1024:.2f}MB")
            except Exception as e:
                st.error(f"Error loading mask: {str(e)}")

    # Prompt input
    prompt = st.text_area(
        "Enter your prompt:",
        placeholder="Describe the changes you want to make to the image...",
        help="Be specific about what you want to add or modify in the masked area"
    )

    # Generate button
    if st.button("Generate Edited Image"):
        if not api_key:
            st.error("Please enter your OpenAI API key in the sidebar.")
            return
            
        if not original_image or not mask_image:
            st.error("Please upload both an original image and a mask image.")
            return
            
        if not prompt:
            st.error("Please enter a prompt describing the desired changes.")
            return

        # Bound before the try so the cleanup in ``finally`` is safe even
        # when saving one of the uploads fails part-way through.
        original_path = None
        mask_path = None

        try:
            with st.spinner("Processing images and generating edited version..."):
                # Save and convert uploaded files
                original_path = save_uploaded_file(original_image)
                mask_path = save_uploaded_file(mask_image)

                # Initialize OpenAI client
                client = OpenAI(api_key=api_key)

                # Make the API call; the context manager closes both file
                # handles so the temporary files can be deleted afterwards.
                with open(original_path, "rb") as image_fp, \
                        open(mask_path, "rb") as mask_fp:
                    response = client.images.edit(
                        model="dall-e-2",
                        image=image_fp,
                        mask=mask_fp,
                        prompt=prompt,
                        n=num_images,
                        size=size_option
                    )

                # Display results
                st.subheader("Generated Images")
                cols = st.columns(num_images)
                
                for idx, image_data in enumerate(response.data):
                    # Download and save the generated image
                    saved_image_path = download_image(image_data.url)
                    
                    if saved_image_path:
                        with cols[idx]:
                            st.image(saved_image_path, caption=f"Generated Image {idx+1}")
                            
                            # Add download button for each image
                            with open(saved_image_path, "rb") as file:
                                st.download_button(
                                    label=f"Download Image {idx+1}",
                                    data=file,
                                    file_name=f"edited_image_{idx+1}.png",
                                    mime="image/png"
                                )

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            
        finally:
            # Cleanup temporary files (runs on success and on failure)
            for path in (original_path, mask_path):
                if path and os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        # Best effort: a leftover temp file must not mask
                        # the real outcome of the request.
                        pass

# Start the app only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()