import streamlit as st
import requests
from openai import OpenAI
from PIL import Image
import io
import os
from datetime import datetime

def preprocess_image(uploaded_file):
    """
    Preprocess the image to meet OpenAI's requirements:
    - Convert to PNG
    - Ensure file size is less than 4MB
    - Resize if necessary while maintaining aspect ratio

    Args:
        uploaded_file: A path or binary file-like object that
            ``PIL.Image.open`` can read.

    Returns:
        Path of the processed PNG written under the ``temp`` directory.
        The caller is responsible for deleting the file.

    Raises:
        ValueError: If the image cannot be shrunk below the size limit.

    NOTE(review): the image-variations endpoint documents a square-PNG
    requirement; this function preserves the aspect ratio and does not crop
    or pad. Confirm whether non-square uploads should be squared here.
    """
    import uuid  # local import: only needed to build a unique file name

    max_size = 1024
    max_bytes = 4 * 1024 * 1024  # 4MB in bytes
    # Modes that are written to PNG unchanged. RGBA is deliberately left out
    # so the alpha channel keeps being dropped; anything else (CMYK, YCbCr,
    # ...) cannot be written as PNG and must be converted first.
    png_modes = {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB"}

    # exist_ok avoids the check-then-create race between concurrent sessions
    os.makedirs("temp", exist_ok=True)

    image = Image.open(uploaded_file)

    if image.mode not in png_modes:
        image = image.convert('RGB')

    # Resize image if it's too large, maintaining aspect ratio
    if image.width > max_size or image.height > max_size:
        ratio = min(max_size / image.width, max_size / image.height)
        # Clamp to 1px so very elongated images never get a zero dimension
        new_size = (
            max(1, int(image.width * ratio)),
            max(1, int(image.height * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    # A random suffix keeps uploads processed within the same second from
    # overwriting each other's temp file
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    temp_path = f"temp/processed_image_{stamp}_{uuid.uuid4().hex[:8]}.png"
    image.save(temp_path, "PNG", optimize=True)

    # Check file size and shrink by 10% per pass until it fits
    while os.path.getsize(temp_path) > max_bytes:
        if image.width == 1 and image.height == 1:
            # Nothing left to shrink; bail out instead of looping forever
            raise ValueError("Unable to reduce image below 4MB")
        image = image.resize(
            (max(1, int(image.width * 0.9)), max(1, int(image.height * 0.9))),
            Image.Resampling.LANCZOS
        )
        image.save(temp_path, "PNG", optimize=True)

    return temp_path

def save_image_from_url(image_url, index):
    """Save image from URL to local file.

    Args:
        image_url: URL the image bytes are downloaded from.
        index: Number embedded in the output file name to tell the
            variations of one batch apart.

    Returns:
        Path of the file written under ``generated_variations``.

    Raises:
        requests.RequestException: If the download times out, fails, or the
            server answers with an HTTP error status.
    """
    # A timeout keeps a stalled download from hanging the app indefinitely
    response = requests.get(image_url, timeout=30)
    # Fail loudly rather than writing an error page to disk as a ".png"
    response.raise_for_status()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"generated_variations/{timestamp}_variation_{index}.png"

    # exist_ok avoids the check-then-create race between concurrent sessions
    os.makedirs("generated_variations", exist_ok=True)

    with open(output_path, "wb") as f:
        f.write(response.content)
    return output_path

def main():
    """Render the Streamlit page and drive the variation workflow.

    Asks for an API key in the sidebar, lets the user upload an image and
    pick the number and size of variations, then on button press
    preprocesses the upload, requests variations from the OpenAI client and
    shows each result with a download button. Errors are reported in the
    page rather than raised.
    """
    st.title("OpenAI Image Variation Generator")

    # Sidebar for API key
    st.sidebar.header("Settings")
    api_key = st.sidebar.text_input("Enter OpenAI API Key", type="password")

    if not api_key:
        st.warning("Please enter your OpenAI API key in the sidebar to continue.")
        return

    # Main content
    st.write("Upload an image to generate variations using DALL-E 2")

    # Image upload with clear file type instructions
    st.info("Please upload a PNG, JPG, or JPEG image. The image will be automatically processed to meet OpenAI's requirements (PNG format, < 4MB).")
    uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])

    # Control options
    col1, col2 = st.columns(2)
    with col1:
        num_variations = st.slider("Number of variations", min_value=1, max_value=4, value=1)
    with col2:
        size_options = ["1024x1024", "512x512", "256x256"]
        selected_size = st.selectbox("Image size", size_options)

    if uploaded_file is None:
        return

    try:
        # Display uploaded image
        st.subheader("Uploaded Image")
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Generate variations button
        if st.button("Generate Variations"):
            # Stays None until preprocessing has produced a file, so the
            # cleanup below knows whether there is anything to delete
            temp_path = None
            try:
                # Preprocess and save image
                with st.spinner("Processing image..."):
                    temp_path = preprocess_image(uploaded_file)

                # Show processed image details
                file_size_mb = os.path.getsize(temp_path) / (1024 * 1024)
                st.success(f"Image processed successfully! File size: {file_size_mb:.2f}MB")

                # Initialize OpenAI client
                client = OpenAI(api_key=api_key)

                with st.spinner("Generating variations..."):
                    # Close the upload handle as soon as the request is done
                    # so the temp file can be removed afterwards
                    with open(temp_path, "rb") as image_file:
                        response = client.images.create_variation(
                            model="dall-e-2",
                            image=image_file,
                            n=num_variations,
                            size=selected_size
                        )

                    # Display generated variations
                    st.subheader("Generated Variations")
                    cols = st.columns(num_variations)

                    for idx, image_data in enumerate(response.data):
                        # Save and display each variation
                        saved_path = save_image_from_url(image_data.url, idx)
                        with cols[idx]:
                            st.image(saved_path, caption=f"Variation {idx+1}", use_container_width=True)
                            with open(saved_path, "rb") as file:
                                st.download_button(
                                    label=f"Download Variation {idx+1}",
                                    data=file,
                                    file_name=f"variation_{idx+1}.png",
                                    mime="image/png"
                                )

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
                if "invalid_request_error" in str(e):
                    st.info("Please ensure your image meets OpenAI's requirements: PNG format, less than 4MB, and appropriate content.")
            finally:
                # Cleanup temporary file on failure as well as on success
                if temp_path is not None and os.path.exists(temp_path):
                    os.remove(temp_path)

    except Exception as e:
        st.error(f"Error loading image: {str(e)}")

# Build the page only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()